package com.jstarcraft.ai.jsat.classifiers;

import java.util.Collections;
import java.util.List;
import java.util.Objects;

import com.jstarcraft.ai.jsat.classifiers.calibration.BinaryScoreClassifier;
import com.jstarcraft.ai.jsat.parameters.Parameter;
import com.jstarcraft.ai.jsat.parameters.Parameterized;
import com.jstarcraft.ai.jsat.regression.RegressionDataSet;
import com.jstarcraft.ai.jsat.regression.Regressor;

/**
 * This meta algorithm wraps a {@link Regressor} to perform binary
 * classification. This is done by labeling class 0 data points as "-1" and
 * class 1 points as "1". The sign of the regressor's output then determines
 * the class. Not all regression algorithms will work well in this setting, and
 * standard parameter values may need to change. <br>
 * The parameter values returned are exactly those provided by the given
 * regressor, or an empty list if the regressor does not implement
 * {@link Parameterized}.
 * 
 * @author Edward Raff
 */
public class RegressorToClassifier implements BinaryScoreClassifier, Parameterized {

    private static final long serialVersionUID = -2607433019826385335L;

    /** The wrapped regressor; all training and scoring is delegated to it. */
    private final Regressor regressor;

    /**
     * Creates a new binary classifier backed by the given regressor.
     * 
     * @param regressor the regressor to wrap as a binary classifier
     * @throws NullPointerException if {@code regressor} is {@code null}
     */
    public RegressorToClassifier(Regressor regressor) {
        // Fail fast here rather than with an NPE on the first train/score call.
        this.regressor = Objects.requireNonNull(regressor, "regressor");
    }

    /**
     * Returns the raw output of the wrapped regressor for the given point.
     * Positive scores correspond to class 1, all others to class 0.
     * 
     * @param dp the data point to score
     * @return the regressor's prediction for {@code dp}
     */
    @Override
    public double getScore(DataPoint dp) {
        return regressor.regress(dp);
    }

    /**
     * Creates a copy of this classifier that wraps a clone of the regressor.
     * 
     * @return a new {@code RegressorToClassifier} around
     *         {@code regressor.clone()}
     */
    @Override
    public RegressorToClassifier clone() {
        return new RegressorToClassifier(regressor.clone());
    }

    /**
     * Classifies a point by the sign of its score. A strictly positive score
     * puts all probability mass on class 1. Zero, negative, and NaN scores put
     * it on class 0, because {@code NaN > 0} is false.
     * 
     * @param data the data point to classify
     * @return a two-class result with probability 1.0 on the predicted class
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);
        if (getScore(data) > 0) {
            cr.setProb(1, 1.0);
        } else {
            cr.setProb(0, 1.0);
        }
        return cr;
    }

    /**
     * Converts the classification data into a {-1, +1} regression problem and
     * trains the wrapped regressor on it.
     * 
     * @param dataSet  the binary classification data to train on
     * @param parallel passed through to the regressor's training call
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        RegressionDataSet rds = getRegressionDataSet(dataSet);
        regressor.train(rds, parallel);
    }

    /**
     * Converts the classification data into a {-1, +1} regression problem and
     * trains the wrapped regressor on it.
     * 
     * @param dataSet the binary classification data to train on
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        RegressionDataSet rds = getRegressionDataSet(dataSet);
        regressor.train(rds);
    }

    /**
     * Reports weighted-data support as whatever the wrapped regressor reports.
     * 
     * @return {@code regressor.supportsWeightedData()}
     */
    @Override
    public boolean supportsWeightedData() {
        return regressor.supportsWeightedData();
    }

    /**
     * Builds a regression data set with the same features as {@code dataSet}.
     * Each point's target is {@code category * 2 - 1}, so class 0 becomes -1
     * and class 1 becomes +1. Categories above 1 would map to larger targets,
     * so the input is expected to be binary.
     */
    private RegressionDataSet getRegressionDataSet(ClassificationDataSet dataSet) {
        RegressionDataSet rds = new RegressionDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories());
        for (int i = 0; i < dataSet.size(); i++) {
            // NOTE(review): only the point and its target are copied here. Whether
            // per-point weights survive depends on whether DataPoint carries them in
            // this version; confirm, since supportsWeightedData() defers to the regressor.
            rds.addDataPoint(dataSet.getDataPoint(i), dataSet.getDataPointCategory(i) * 2 - 1);
        }
        return rds;
    }

    /**
     * Returns the wrapped regressor's parameters, or an immutable empty list if
     * it does not implement {@link Parameterized}.
     * 
     * @return the tunable parameters of the wrapped regressor
     */
    @Override
    public List<Parameter> getParameters() {
        if (regressor instanceof Parameterized) {
            return ((Parameterized) regressor).getParameters();
        }
        // Type-safe replacement for the raw Collections.EMPTY_LIST.
        return Collections.emptyList();
    }

    /**
     * Looks up a parameter of the wrapped regressor by name.
     * 
     * @param paramName the name of the parameter to find
     * @return the regressor's matching parameter, or {@code null} if the
     *         regressor does not implement {@link Parameterized}
     */
    @Override
    public Parameter getParameter(String paramName) {
        if (regressor instanceof Parameterized) {
            return ((Parameterized) regressor).getParameter(paramName);
        }
        return null;
    }

}
